from __future__ import annotations

import time

# import torch
import copy
# import torch.nn as nn
# from torch import optim
# from torch.distributions import Normal
from typing import Dict, Union
import tensorflow as tf
import tensorflow_probability as tfp
from rsl_rl.modules.mlp import MLPModel
from rsl_rl.tf_version.dpg_tf import AbstractDPG
from rsl_rl.env import VecEnv
from rsl_rl.modules.network import Network
from rsl_rl.storage.storage import Dataset
from rsl_rl.modules.actor_critic import ActorCritic
from rsl_rl.storage.rollout_storage import RolloutStorage
from rsl_rl.storage.rollout_storage_old import RolloutStorageOld


class DDPGTF(AbstractDPG):
    """Deterministic-policy-gradient agent with actor/critic and target networks in TensorFlow.

    The agent owns four networks (actor, critic and one target copy of each),
    one Adam optimizer per trained network, and a trainable per-action standard
    deviation used to build a Normal exploration distribution around the
    actor output.
    """

    def __init__(
            self,
            env: VecEnv,
            actor_lr: float = 1e-4,
            critic_lr: float = 1e-3,
            **kwargs,
    ) -> None:
        """Build networks, optimizers and exploration state.

        Args:
            env: Environment handed to the base-class constructor.
            actor_lr: Learning rate of the actor's Adam optimizer.
            critic_lr: Learning rate of the critic's Adam optimizer.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(env, **kwargs)
        # Hard-coded network sizes. NOTE(review): 60 is presumably the 48-wide
        # observation plus the 12-wide action - confirm against `_critic_input`.
        self._critic_input_size = 60

        self.actor = MLPModel(shape_input=48, shape_output=12, name="actor", output_activation="tanh").model
        self.critic = MLPModel(
            shape_input=self._critic_input_size, shape_output=1, name="critic", output_activation=None
        ).model
        self.target_actor = MLPModel(
            shape_input=48, shape_output=12, name="target_actor", output_activation="tanh"
        ).model
        self.target_critic = MLPModel(
            shape_input=self._critic_input_size, shape_output=1, name="target_critic", output_activation=None
        ).model

        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)
        # The critic must use its own learning rate; it was previously built
        # with `actor_lr`, which silently ignored the `critic_lr` argument.
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)
        self.std_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

        self.trans_list = []
        self.transition = RolloutStorage.Transition()
        init_noise_std = 0.1
        self.lam = 0.95
        # Trainable per-action standard deviation of the exploration noise.
        self.std = tf.Variable(initial_value=init_noise_std * tf.ones(self._action_size), trainable=True, name='std')

        self._register_serializable(
            "actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer"
        )

    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
        """Do nothing; the arguments are accepted and ignored."""
        pass

    def eval_mode(self) -> DDPGTF:
        """Switch the agent and its four networks to evaluation mode and return ``self``."""
        super().eval_mode()

        # NOTE(review): `.eval()` is a PyTorch-module method; confirm that the
        # objects returned by `MLPModel(...).model` actually provide it.
        self.actor.eval()
        self.critic.eval()
        self.target_actor.eval()
        self.target_critic.eval()

        return self

    def act(self, obs, critic_obs):
        """Sample exploratory actions for ``obs``; ``critic_obs`` is not used."""
        return self.nn_act(obs)

    def to(self, device: str) -> DDPGTF:
        """Transfers agent parameters to device."""
        super().to(device)

        # NOTE(review): `.to(device)` is a PyTorch-module method; confirm that
        # the objects returned by `MLPModel(...).model` actually provide it.
        self.actor.to(device)
        self.critic.to(device)
        self.target_actor.to(device)
        self.target_critic.to(device)

        return self

    def train_mode(self) -> DDPGTF:
        """Switch the agent and its four networks to training mode and return ``self``."""
        super().train_mode()

        # NOTE(review): `.train()` is a PyTorch-module method; confirm that the
        # objects returned by `MLPModel(...).model` actually provide it.
        self.actor.train()
        self.critic.train()
        self.target_actor.train()
        self.target_critic.train()

        return self

    def update(self, dataset: Dataset) -> tuple[tf.Tensor, tf.Tensor]:
        """Run one critic and one actor gradient step per stored batch.

        For every batch the critic is regressed onto the one-step TD target
        built from the target networks, the actor is then moved to maximise
        the (already updated) critic's value of its own actions, and finally
        both target networks are refreshed via ``_update_target``.

        Args:
            dataset: Passed to the base-class ``update`` before training.

        Returns:
            ``(mean critic loss, mean actor loss)`` over all batches, in that
            order. Both are NaN when the batch generator yields nothing.
        """
        super().update(dataset)

        total_actor_loss = []
        total_critic_loss = []

        for batch in self.storage.batch_generator(self._batch_size, self._batch_count):
            actor_obs = batch["actor_observations"]
            critic_obs = batch["critic_observations"]
            actions = batch["actions"]
            rewards = batch["rewards"]
            actor_next_obs = batch["next_actor_observations"]
            critic_next_obs = batch["next_critic_observations"]
            dones = tf.cast(batch["dones"], dtype=tf.float32)

            # The TD target only depends on the target networks, so it is
            # computed outside any tape and acts as a constant for the critic.
            target_actor_prediction = self._process_actions(self.target_actor(actor_next_obs))
            target_critic_prediction = self.target_critic(
                self._critic_input(critic_next_obs, target_actor_prediction)
            )
            # NOTE(review): assumes rewards/dones broadcast element-wise against
            # the critic output (e.g. both (batch, 1)) - confirm storage shapes.
            target = rewards + self._discount_factor * (1 - dones) * target_critic_prediction

            # Optimize critic. Gradients are taken after the tape context exits
            # so the gradient computation itself is not recorded.
            with tf.GradientTape() as critic_tape:
                prediction = self.critic(self._critic_input(critic_obs, actions))
                critic_loss = tf.reduce_mean(tf.square(prediction - target))
            critic_variables = self.critic.trainable_variables
            critic_gradients = critic_tape.gradient(critic_loss, critic_variables)
            self.critic_optimizer.apply_gradients(zip(critic_gradients, critic_variables))

            # Optimize actor against the freshly updated critic.
            with tf.GradientTape() as actor_tape:
                evaluation = self.critic(
                    self._critic_input(critic_obs, self._process_actions(self.actor(actor_obs)))
                )
                actor_loss = -tf.reduce_mean(evaluation)
            actor_variables = self.actor.trainable_variables
            actor_gradients = actor_tape.gradient(actor_loss, actor_variables)
            self.actor_optimizer.apply_gradients(zip(actor_gradients, actor_variables))

            self._update_target(self.actor, self.target_actor)
            self._update_target(self.critic, self.target_critic)

            total_actor_loss.append(actor_loss.numpy())
            total_critic_loss.append(critic_loss.numpy())

        return tf.reduce_mean(total_critic_loss), tf.reduce_mean(total_actor_loss)

    def process_env_step2(self, prev_obs, obs, actions, rewards, dones, infos):
        """Pack one environment step into a transition dictionary.

        The same observation tensor is stored for both the actor and the
        critic. ``infos`` must contain a ``'time_outs'`` entry.

        Returns:
            Dict with copies (``tf.identity``) of the given tensors under the
            keys consumed by ``update``, plus ``'timeouts'``.
        """
        res = {
            'actor_observations': tf.identity(prev_obs),
            'critic_observations': tf.identity(prev_obs),
            'actions': tf.identity(actions),
            'rewards': tf.identity(rewards),
            'next_actor_observations': tf.identity(obs),
            'next_critic_observations': tf.identity(obs),
            'dones': tf.identity(dones),
            'timeouts': tf.identity(infos['time_outs'])
        }

        return res

    def process_env_step(self, rewards, dones, infos):
        """Record rewards/dones on the pending transition and push it to storage.

        When ``infos`` carries ``'time_outs'``, the discounted stored value is
        added to the reward of timed-out steps before the transition is saved.
        """
        self.transition.rewards = tf.identity(rewards)
        self.transition.dones = dones

        # Bootstrapping on time outs
        if 'time_outs' in infos:
            # NOTE(review): nothing in this class assigns `self.transition.values`
            # or `self.gamma` (`update` uses `self._discount_factor`); confirm
            # both are provided elsewhere before relying on this branch.
            self.transition.rewards += self.gamma * tf.squeeze(
                self.transition.values * tf.expand_dims(tf.convert_to_tensor(infos['time_outs']), axis=1),
                axis=1)

        # Record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.reset(dones)

    def compute_returns(self, last_critic_obs):
        """Evaluate the critic on the last observations and let storage compute returns."""
        # `tf.stop_gradient` is the TensorFlow counterpart of torch's
        # `.detach()`, which TF tensors do not provide.
        last_values = tf.stop_gradient(self.evaluate(last_critic_obs))
        # NOTE(review): `self.gamma` is not assigned in this class - confirm the
        # base class defines it.
        self.storage.compute_returns(last_values, self.gamma, self.lam)

    def update_distribution(self, observations):
        """Build the Normal action distribution centred on the actor output."""
        mean = self.actor(observations)
        # `mean * 0 + self.std` expands the per-action std to the shape of `mean`.
        self.distribution = tfp.distributions.Normal(loc=mean, scale=mean * 0 + self.std)

    def nn_act(self, observations, **kwargs):
        """Refresh the action distribution for ``observations`` and sample from it."""
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        """Return the log-probability of ``actions`` summed over the last axis.

        Uses the distribution built by the latest ``update_distribution`` call.
        """
        return tf.reduce_sum(self.distribution.log_prob(actions), axis=-1)

    def act_inference(self, observations):
        """Return the raw actor output for ``observations`` (no sampling)."""
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        """Return the critic output for ``critic_observations``."""
        # NOTE(review): the critic is built with a 60-wide input while `update`
        # feeds it `_critic_input(obs, actions)`; confirm callers pass inputs of
        # the matching width here.
        value = self.critic(critic_observations)
        return value

    @property
    def action_mean(self):
        """Mean of the most recently built action distribution."""
        # TFP exposes `mean` as a method, so it has to be called.
        return self.distribution.mean()

    @property
    def action_std(self):
        """Standard deviation of the most recently built action distribution."""
        # TFP exposes `stddev` as a method, so it has to be called.
        return self.distribution.stddev()

    def reset(self, dones=None):
        """Do nothing; ``dones`` is accepted and ignored."""
        pass
